diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index 9c9dd3b5ff..cfab2dde43 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -155,10 +155,12 @@ end Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear() -Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} = - arrayref(A, i1) -Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} = - arrayset(A, convert(T, x)::T, i1) +Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} = arrayref( + A, i1 +) +Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} = arrayset( + A, convert(T, x)::T, i1 +) # preserve the specific integer type when indexing device arrays, # to avoid extending 32-bit hardware indices to 64-bit. @@ -166,9 +168,9 @@ Base.to_index(::CuTracedArray, i::Integer) = i # Base doesn't like Integer indices, so we need our own ND get and setindex! routines. # See also: https://github.com/JuliaLang/julia/pull/42289 -Base.@propagate_inbounds Base.getindex( - A::CuTracedArray, I::Union{Integer,CartesianIndex}... -) = A[Base._to_linear_index(A, to_indices(A, I)...)] +Base.@propagate_inbounds Base.getindex(A::CuTracedArray, I::Union{Integer,CartesianIndex}...) = A[Base._to_linear_index( + A, to_indices(A, I)... +)] Base.@propagate_inbounds Base.setindex!( A::CuTracedArray, x, I::Union{Integer,CartesianIndex}... ) = A[Base._to_linear_index(A, to_indices(A, I)...)] = x diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index 7279ba585b..1514dbe79a 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -143,12 +143,10 @@ function trace_for(mod, expr) step = length(range.args) == 3 ? 1 : range.args[3] limit = range.args[end] - body_symbols = ExpressionExplorer.compute_symbols_state( - quote - $(Expr(:local, assign)) - $body - end, - ) + body_symbols = ExpressionExplorer.compute_symbols_state(quote + $(Expr(:local, assign)) + $body + end) external_syms = body_symbols.assignments ∪ body_symbols.references filter!(∉(SPECIAL_SYMBOLS), external_syms) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index 511867c649..f64bc46a1d 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -284,7 +284,7 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N end # TODO is there any way to allocate an uninitialized buffer in XLA? -function Base.similar(a::ConcreteRArray{T}, ::Type{S}=T, dims::Dims=size(a)) where {T,S} +function Base.similar(a::ConcreteRArray{T}, (::Type{S})=T, dims::Dims=size(a)) where {T,S} return ConcreteRArray( Array{S}(undef, dims); client=XLA.client(a.data), device=XLA.device(a.data) ) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 449f78de38..4f034d29f5 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -64,10 +64,10 @@ end ReactantCacheToken(), REACTANT_METHOD_TABLE, world, - false, #=forward_rules=# - false, #=reverse_rules=# - false, #=inactive_rules=# - false, #=broadcast_rewrite=# + false, + false, + false, + false, set_reactant_abi, ) end @@ -81,10 +81,10 @@ else REACTANT_CACHE, REACTANT_METHOD_TABLE, world, - false, #=forward_rules=# - false, #=reverse_rules=# - false, #=inactive_rules=# - false, #=broadcast_rewrite=# + false, + false, + false, + false, set_reactant_abi, ) end @@ -97,20 +97,25 @@ const enzyme_dupnoneed = 3 const enzyme_outnoneed = 4 const enzyme_constnoneed = 5 -@inline act_from_type(x, reverse, needs_primal=true) = - throw(AssertionError("Unhandled activity $(typeof(x))")) -@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) = - act_from_type(Enzyme.Const, reverse, needs_primal) -@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) = - act_from_type(Enzyme.Duplicated, reverse, needs_primal) +@inline act_from_type(x, reverse, needs_primal=true) = throw( + AssertionError("Unhandled activity $(typeof(x))") +) +@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) = act_from_type( + Enzyme.Const, reverse, needs_primal +) +@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) = act_from_type( + Enzyme.Duplicated, reverse, needs_primal +) @inline act_from_type(::Enzyme.DuplicatedNoNeed, reverse, needs_primal=true) = reverse ? enzyme_out : enzyme_dupnoneed -@inline act_from_type(::Enzyme.BatchDuplicated, reverse, needs_primal=true) = - act_from_type(Enzyme.Duplicated, reverse, needs_primal) +@inline act_from_type(::Enzyme.BatchDuplicated, reverse, needs_primal=true) = act_from_type( + Enzyme.Duplicated, reverse, needs_primal +) @inline act_from_type(::Enzyme.BatchDuplicatedNoNeed, reverse, needs_primal=true) = reverse ? enzyme_out : enzyme_dupnoneed -@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) = - act_from_type(Enzyme.Active, reverse, needs_primal) +@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) = act_from_type( + Enzyme.Active, reverse, needs_primal +) @inline act_from_type(::Type{<:Enzyme.Const}, reverse, needs_primal) = if needs_primal enzyme_const @@ -132,10 +137,12 @@ const enzyme_constnoneed = 5 end end -@inline act_from_type(::Type{<:Enzyme.BatchDuplicated}, reverse, needs_primal) = - act_from_type(Enzyme.Duplicated, reverse, needs_primal) -@inline act_from_type(::Type{<:Enzyme.BatchDuplicatedNoNeed}, reverse, needs_primal) = - act_from_type(Enzyme.DuplicatedNoNeed, Reverse, needs_primal) +@inline act_from_type(::Type{<:Enzyme.BatchDuplicated}, reverse, needs_primal) = act_from_type( + Enzyme.Duplicated, reverse, needs_primal +) +@inline act_from_type(::Type{<:Enzyme.BatchDuplicatedNoNeed}, reverse, needs_primal) = act_from_type( + Enzyme.DuplicatedNoNeed, Reverse, needs_primal +) @inline act_from_type(::Type{<:Enzyme.Active}, reverse, needs_primal) = if needs_primal @@ -465,7 +472,7 @@ function overload_autodiff( false, TracedUtils.transpose_val(MLIR.IR.result(res, residx)); emptypaths=true, - ) #=reverse=# + ) residx += 1 continue end diff --git a/src/Overlay.jl b/src/Overlay.jl index 5d9b85c838..6204f87cdb 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -80,7 +80,7 @@ for randfun in (:rand, :randn, :randexp) # scalars @reactant_overlay @noinline function Random.$(randfun)( - rng::AbstractRNG, ::Type{T}=Float64 + rng::AbstractRNG, (::Type{T})=Float64 ) where {T} if T <: ReactantPrimitive return TracedRandom.$(overload_randfun)(rng, T) diff --git a/src/Reactant.jl b/src/Reactant.jl index 14807f3f81..0233cae411 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -217,7 +217,7 @@ include("Compiler.jl") include("Overlay.jl") function Enzyme.make_zero( - ::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false) + ::Type{RT}, seen::IdDict, prev::RT, (::Val{copy_if_inactive})=Val(false) )::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}} if haskey(seen, prev) return seen[prev] diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 4c05527cb6..6790bcb991 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -345,9 +345,8 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where { end indices = [ - ( - TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1 - ).mlir_data for i in indices + (TracedUtils.promote_to(TracedRNumber{Int}, i isa Colon ? 1 : first(i)) - 1).mlir_data + for i in indices ] res = MLIR.IR.result( MLIR.Dialects.stablehlo.dynamic_update_slice( diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index ce21a2fad4..5c77c49b71 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -134,7 +134,7 @@ for (jlop, hloop, hlocomp) in ( function $(jlop)( @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T}) ) where {T} - return Ops.compare(lhs, rhs; comparison_direction=$(hlocomp)) + return Ops.compare(lhs, rhs; comparison_direction=($(hlocomp))) end function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T} diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 0a4e218e0a..47a22d219b 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -35,17 +35,17 @@ function materialize_traced_array( end get_mlir_data(x::TracedRNumber) = x.mlir_data -set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data = data; return x) +set_mlir_data!(x::TracedRNumber, data) = (x.mlir_data=data; return x) get_paths(x::TracedRNumber) = x.paths -set_paths!(x::TracedRNumber, paths) = (x.paths = paths; return x) +set_paths!(x::TracedRNumber, paths) = (x.paths=paths; return x) get_mlir_data(x::TracedRArray) = x.mlir_data get_mlir_data(x::AnyTracedRArray) = get_mlir_data(materialize_traced_array(x)) get_paths(x::TracedRArray) = x.paths -set_paths!(x::TracedRArray, paths) = (x.paths = paths; return x) +set_paths!(x::TracedRArray, paths) = (x.paths=paths; return x) get_paths(x::MissingTracedValue) = x.paths -set_paths!(x::MissingTracedValue, paths) = (x.paths = paths; return x) +set_paths!(x::MissingTracedValue, paths) = (x.paths=paths; return x) function set_mlir_data!(x::TracedRArray, data) x.mlir_data = data diff --git a/src/Tracing.jl b/src/Tracing.jl index a439aa93f7..e79f2ce150 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -207,15 +207,21 @@ Base.@nospecializeinfer function traced_type_inner( end end -Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::UnionAll)) = - UnionAll(TV.var, base_typet(TV.body)) -Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::DataType)) = - TracedRArray{TV.parameters...} - -Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::UnionAll)) = - UnionAll(TV.var, base_typec(TV.body)) -Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) = - (TV <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...} +Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::UnionAll)) = UnionAll( + TV.var, base_typet(TV.body) +) +Base.@nospecializeinfer @inline base_typet(@nospecialize(TV::DataType)) = TracedRArray{ + TV.parameters... +} + +Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::UnionAll)) = UnionAll( + TV.var, base_typec(TV.body) +) +Base.@nospecializeinfer @inline base_typec(@nospecialize(TV::DataType)) = ( + TV<:TracedRArray ? ConcreteRArray : ConcreteRNumber +){ + TV.parameters... +} Base.@nospecializeinfer function traced_type_inner( @nospecialize(T::Type{<:ConcreteRArray}), diff --git a/src/mlir/IR/AffineExpr.jl b/src/mlir/IR/AffineExpr.jl index 7a0f4931d6..4519ceac69 100644 --- a/src/mlir/IR/AffineExpr.jl +++ b/src/mlir/IR/AffineExpr.jl @@ -56,8 +56,9 @@ ismultipleof(expr::AffineExpr, factor) = API.mlirAffineExprIsMultipleOf(expr, fa Checks whether the given affine expression involves AffineDimExpr 'position'. """ -isfunctionofdimexpr(expr::AffineExpr, position) = +function isfunctionofdimexpr(expr::AffineExpr, position) API.mlirAffineExprIsFunctionOfDim(expr, position) +end """ isdimexpr(affineExpr) @@ -71,8 +72,9 @@ isdimexpr(expr::AffineExpr) = API.mlirAffineExprIsADim(expr) Creates an affine dimension expression with 'position' in the context. """ -AffineDimensionExpr(position; context::Context=context()) = +function AffineDimensionExpr(position; context::Context=context()) AffineExpr(API.mlirAffineDimExprGet(context, position)) +end """ issymbolexpr(affineExpr) @@ -86,8 +88,9 @@ issymbolexpr(expr::AffineExpr) = API.mlirAffineExprIsASymbol(expr) Creates an affine symbol expression with 'position' in the context. """ -SymbolExpr(position; context::Context=context()) = +function SymbolExpr(position; context::Context=context()) AffineExpr(API.mlirAffineSymbolExprGet(context, position)) +end """ position(affineExpr) @@ -120,8 +123,9 @@ isconstantexpr(expr::AffineExpr) = API.mlirAffineExprIsAConstant(expr) Creates an affine constant expression with 'constant' in the context. """ -ConstantExpr(constant; context::Context=context()) = +function ConstantExpr(constant; context::Context=context()) AffineExpr(API.mlirAffineConstantExprGet(context, constant)) +end """ value(affineExpr) @@ -189,8 +193,9 @@ isfloordiv(expr::AffineExpr) = API.mlirAffineExprIsAFloorDiv(expr) Creates an affine floordiv expression with 'lhs' and 'rhs'. """ -Base.div(lhs::AffineExpr, rhs::AffineExpr) = +function Base.div(lhs::AffineExpr, rhs::AffineExpr) AffineExpr(API.mlirAffineFloorDivExprGet(lhs, rhs)) +end Base.fld(lhs::AffineExpr, rhs::AffineExpr) = div(lhs, rhs) """ @@ -205,8 +210,9 @@ isceildiv(expr::AffineExpr) = API.mlirAffineExprIsACeilDiv(expr) Creates an affine ceildiv expression with 'lhs' and 'rhs'. """ -Base.cld(lhs::AffineExpr, rhs::AffineExpr) = +function Base.cld(lhs::AffineExpr, rhs::AffineExpr) AffineExpr(API.mlirAffineCeilDivExprGet(lhs, rhs)) +end """ isbinary(affineExpr) diff --git a/src/mlir/IR/AffineMap.jl b/src/mlir/IR/AffineMap.jl index 1a36f724d7..f0e07adf06 100644 --- a/src/mlir/IR/AffineMap.jl +++ b/src/mlir/IR/AffineMap.jl @@ -44,8 +44,9 @@ context(map::AffineMap) = API.mlirAffineMapGetContext(map) Creates a zero result affine map of the given dimensions and symbols in the context. The affine map is owned by the context. """ -AffineMap(ndims, nsymbols; context::Context=context()) = +function AffineMap(ndims, nsymbols; context::Context=context()) AffineMap(API.mlirAffineMapZeroResultGet(context, ndims, nsymbols)) +end """ AffineMap(ndims, nsymbols, affineExprs; context=context()) @@ -53,24 +54,27 @@ AffineMap(ndims, nsymbols; context::Context=context()) = Creates an affine map with results defined by the given list of affine expressions. The map resulting map also has the requested number of input dimensions and symbols, regardless of them being used in the results. """ -AffineMap(ndims, nsymbols, exprs::Vector{AffineExpr}; context::Context=context()) = +function AffineMap(ndims, nsymbols, exprs::Vector{AffineExpr}; context::Context=context()) AffineMap(API.mlirAffineMapGet(context, ndims, nsymbols, length(exprs), exprs)) +end """ ConstantAffineMap(val; context=context()) Creates a single constant result affine map in the context. The affine map is owned by the context. """ -ConstantAffineMap(val; context::Context=context()) = +function ConstantAffineMap(val; context::Context=context()) AffineMap(API.mlirAffineMapConstantGet(context, val)) +end """ IdentityAffineMap(ndims; context=context()) Creates an affine map with 'ndims' identity in the context. The affine map is owned by the context. """ -IdentityAffineMap(ndims; context::Context=context()) = +function IdentityAffineMap(ndims; context::Context=context()) AffineMap(API.mlirAffineMapMultiDimIdentityGet(context, ndims)) +end """ MinorIdentityAffineMap(ndims, nresults; context=context()) @@ -189,8 +193,9 @@ Base.isperm(map::AffineMap) = API.mlirAffineMapIsPermutation(map) Returns the affine map consisting of the `positions` subset. """ -submap(map::AffineMap, pos::Vector{Int}) = +function submap(map::AffineMap, pos::Vector{Int}) AffineMap(API.mlirAffineMapGetSubMap(map, length(pos), pos)) +end """ majorsubmap(affineMap, nresults) @@ -199,8 +204,9 @@ Returns the affine map consisting of the most major `nresults` results. Returns the null AffineMap if the `nresults` is equal to zero. Returns the `affineMap` if `nresults` is greater or equals to number of results of the given affine map. """ -majorsubmap(map::AffineMap, nresults) = +function majorsubmap(map::AffineMap, nresults) AffineMap(API.mlirAffineMapGetMajorSubMap(map, nresults)) +end """ minorsubmap(affineMap, nresults) @@ -208,19 +214,24 @@ majorsubmap(map::AffineMap, nresults) = Returns the affine map consisting of the most minor `nresults` results. Returns the null AffineMap if the `nresults` is equal to zero. Returns the `affineMap` if `nresults` is greater or equals to number of results of the given affine map. """ -minorsubmap(map::AffineMap, nresults) = +function minorsubmap(map::AffineMap, nresults) AffineMap(API.mlirAffineMapGetMinorSubMap(map, nresults)) +end """ mlirAffineMapReplace(affineMap, expression => replacement, numResultDims, numResultSyms) Apply `AffineExpr::replace(map)` to each of the results and return a new new AffineMap with the new results and the specified number of dims and symbols. """ -Base.replace( +function Base.replace( map::AffineMap, old_new::Pair{AffineExpr,AffineExpr}, nresultdims, nresultsyms -) = AffineMap( - API.mlirAffineMapReplace(map, old_new.first, old_new.second, nresultdims, nresultsyms), ) + AffineMap( + API.mlirAffineMapReplace( + map, old_new.first, old_new.second, nresultdims, nresultsyms + ), + ) +end """ simplify(affineMaps, size, result, populateResult) diff --git a/src/mlir/IR/Attribute.jl b/src/mlir/IR/Attribute.jl index d7aac00830..9ad40d8340 100644 --- a/src/mlir/IR/Attribute.jl +++ b/src/mlir/IR/Attribute.jl @@ -16,8 +16,9 @@ Base.convert(::Core.Type{API.MlirAttribute}, attribute::Attribute) = attribute.a Parses an attribute. The attribute is owned by the context. """ -Base.parse(::Core.Type{Attribute}, str; context::Context=context()) = +function Base.parse(::Core.Type{Attribute}, str; context::Context=context()) Attribute(API.mlirAttributeParseGet(context, str)) +end """ ==(a1, a2) @@ -80,8 +81,9 @@ isarray(attr::Attribute) = API.mlirAttributeIsAArray(attr) Creates an array element containing the given list of elements in the given context. """ -Attribute(attrs::Vector{Attribute}; context::Context=context()) = +function Attribute(attrs::Vector{Attribute}; context::Context=context()) Attribute(API.mlirArrayAttrGet(context, length(attrs), attrs)) +end """ isdict(attr) @@ -168,8 +170,9 @@ isinteger(attr::Attribute) = API.mlirAttributeIsAInteger(attr) Creates an integer attribute of the given type with the given integer value. """ -Attribute(i::T, type=Type(T)) where {T<:Integer} = +function Attribute(i::T, type=Type(T)) where {T<:Integer} Attribute(API.mlirIntegerAttrGet(type, Int64(i))) +end """ Int64(attr) @@ -237,8 +240,9 @@ isopaque(attr::Attribute) = API.mlirAttributeIsAOpaque(attr) Creates an opaque attribute in the given context associated with the dialect identified by its namespace. The attribute contains opaque byte data of the specified length (data need not be null-terminated). """ -OpaqueAttribute(namespace, data, type; context::Context=context) = +function OpaqueAttribute(namespace, data, type; context::Context=context) Attribute(API.mlirOpaqueAttrGet(context, namespace, length(data), data, type)) +end """ mlirOpaqueAttrGetDialectNamespace(attr) @@ -272,8 +276,9 @@ isstring(attr::Attribute) = API.mlirAttributeIsAString(attr) Creates a string attribute in the given context containing the given string. """ -Attribute(str::AbstractString; context::Context=context()) = +function Attribute(str::AbstractString; context::Context=context()) Attribute(API.mlirStringAttrGet(context, str)) +end """ Attribute(type, str) @@ -307,9 +312,11 @@ issymbolref(attr::Attribute) = API.mlirAttributeIsASymbolRef(attr) Creates a symbol reference attribute in the given context referencing a symbol identified by the given string inside a list of nested references. Each of the references in the list must not be nested. """ -SymbolRefAttribute( +function SymbolRefAttribute( symbol::String, references::Vector{Attribute}; context::Context=context() -) = Attribute(API.mlirSymbolRefAttrGet(context, symbol, length(references), references)) +) + Attribute(API.mlirSymbolRefAttrGet(context, symbol, length(references), references)) +end """ rootref(attr) @@ -353,8 +360,9 @@ isflatsymbolref(attr::Attribute) = API.mlirAttributeIsAFlatSymbolRef(attr) Creates a flat symbol reference attribute in the given context referencing a symbol identified by the given string. """ -FlatSymbolRefAttribute(symbol::String; context::Context=context()) = +function FlatSymbolRefAttribute(symbol::String; context::Context=context()) Attribute(API.mlirFlatSymbolRefAttrGet(context, symbol)) +end """ flatsymbol(attr) @@ -675,20 +683,27 @@ issparseelements(attr::Attribute) = API.mlirAttributeIsASparseElements(attr) """ function isdensearray end -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Bool}) = - API.mlirAttributeIsADenseBoolArray(attr) -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int8}) = - API.mlirAttributeIsADenseI8Array(attr) -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int16}) = - API.mlirAttributeIsADenseI16Array(attr) -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int32}) = - API.mlirAttributeIsADenseI32Array(attr) -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int64}) = - API.mlirAttributeIsADenseI64Array(attr) -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Float32}) = - API.mlirAttributeIsADenseF32Array(attr) -@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Float64}) = - API.mlirAttributeIsADenseF64Array(attr) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Bool}) = API.mlirAttributeIsADenseBoolArray( + attr +) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int8}) = API.mlirAttributeIsADenseI8Array( + attr +) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int16}) = API.mlirAttributeIsADenseI16Array( + attr +) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int32}) = API.mlirAttributeIsADenseI32Array( + attr +) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Int64}) = API.mlirAttributeIsADenseI64Array( + attr +) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Float32}) = API.mlirAttributeIsADenseF32Array( + attr +) +@llvmversioned min = v"16" isdensearray(attr::Attribute, ::Core.Type{Float64}) = API.mlirAttributeIsADenseF64Array( + attr +) @llvmversioned min = v"16" """ DenseArrayAttribute(array; context=context()) @@ -697,34 +712,32 @@ function isdensearray end """ function DenseArrayAttribute end -@llvmversioned min = v"16" DenseArrayAttribute( - values::AbstractArray{Bool}; context::Context=context() -) = Attribute( +@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Bool}; context::Context=context()) = Attribute( API.mlirDenseBoolArrayGet( context, length(values), AbstractArray{Cint}(to_row_major(values)) ), ) -@llvmversioned min = v"16" DenseArrayAttribute( - values::AbstractArray{Int8}; context::Context=context() -) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), to_row_major(values))) -@llvmversioned min = v"16" DenseArrayAttribute( - values::AbstractArray{UInt8}; context::Context=context() -) = Attribute(API.mlirDenseI8ArrayGet(context, length(values), to_row_major(values))) -@llvmversioned min = v"16" DenseArrayAttribute( - values::AbstractArray{Int16}; context::Context=context() -) = Attribute(API.mlirDenseI16ArrayGet(context, length(values), to_row_major(values))) -@llvmversioned min = v"16" DenseArrayAttribute( - values::AbstractArray{Int32}; context::Context=context() -) = Attribute(API.mlirDenseI32ArrayGet(context, length(values), to_row_major(values))) -@llvmversioned min = v"16" DenseArrayAttribute( - values::AbstractArray{Int64}; context::Context=context() -) = Attribute(API.mlirDenseI64ArrayGet(context, length(values), to_row_major(values))) -@llvmversioned min = v"16" DenseArrayAttribute( - values::AbstractArray{Float32}; context::Context=context() -) = Attribute(API.mlirDenseF32ArrayGet(context, length(values), to_row_major(values))) -@llvmversioned min = v"16" DenseArrayAttribute( - values::AbstractArray{Float64}; context::Context=context() -) = Attribute(API.mlirDenseF64ArrayGet(context, length(values), to_row_major(values))) +@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Int8}; context::Context=context()) = Attribute( + API.mlirDenseI8ArrayGet(context, length(values), to_row_major(values)) +) +@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{UInt8}; context::Context=context()) = Attribute( + API.mlirDenseI8ArrayGet(context, length(values), to_row_major(values)) +) +@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Int16}; context::Context=context()) = Attribute( + API.mlirDenseI16ArrayGet(context, length(values), to_row_major(values)) +) +@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Int32}; context::Context=context()) = Attribute( + API.mlirDenseI32ArrayGet(context, length(values), to_row_major(values)) +) +@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Int64}; context::Context=context()) = Attribute( + API.mlirDenseI64ArrayGet(context, length(values), to_row_major(values)) +) +@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Float32}; context::Context=context()) = Attribute( + API.mlirDenseF32ArrayGet(context, length(values), to_row_major(values)) +) +@llvmversioned min = v"16" DenseArrayAttribute(values::AbstractArray{Float64}; context::Context=context()) = Attribute( + API.mlirDenseF64ArrayGet(context, length(values), to_row_major(values)) +) @llvmversioned min = v"16" Attribute(values::AbstractArray) = DenseArrayAttribute(values) diff --git a/src/mlir/IR/Block.jl b/src/mlir/IR/Block.jl index 5b16edd002..8c6c94e08c 100644 --- a/src/mlir/IR/Block.jl +++ b/src/mlir/IR/Block.jl @@ -82,8 +82,9 @@ end Appends an argument of the specified type to the block. Returns the newly added argument. """ -push_argument!(block::Block, type; location::Location=Location()) = +function push_argument!(block::Block, type; location::Location=Location()) Value(API.mlirBlockAddArgument(block, type, location)) +end """ first_op(block) diff --git a/src/mlir/IR/ExecutionEngine.jl b/src/mlir/IR/ExecutionEngine.jl index 76395e9645..16bb8c9dea 100644 --- a/src/mlir/IR/ExecutionEngine.jl +++ b/src/mlir/IR/ExecutionEngine.jl @@ -56,5 +56,6 @@ end Dump as an object in `fileName`. """ -Base.write(filename::String, jit::ExecutionEngine) = +function Base.write(filename::String, jit::ExecutionEngine) API.mlirExecutionEngineDumpToObjectFile(jit, filename) +end diff --git a/src/mlir/IR/Identifier.jl b/src/mlir/IR/Identifier.jl index 2f58836b65..166bbc7a96 100644 --- a/src/mlir/IR/Identifier.jl +++ b/src/mlir/IR/Identifier.jl @@ -7,8 +7,9 @@ end Gets an identifier with the given string value. """ -Identifier(str::String; context::Context=context()) = +function Identifier(str::String; context::Context=context()) Identifier(API.mlirIdentifierGet(context, str)) +end Base.convert(::Core.Type{API.MlirIdentifier}, id::Identifier) = id.identifier diff --git a/src/mlir/IR/IntegerSet.jl b/src/mlir/IR/IntegerSet.jl index eb57939cc8..0b91156ce6 100644 --- a/src/mlir/IR/IntegerSet.jl +++ b/src/mlir/IR/IntegerSet.jl @@ -12,8 +12,9 @@ end Gets or creates a new canonically empty integer set with the give number of dimensions and symbols in the given context. """ -IntegerSet(ndims, nsymbols; context::Context=context()) = +function IntegerSet(ndims, nsymbols; context::Context=context()) IntegerSet(API.mlirIntegerSetEmptyGet(context, ndims, nsymbols)) +end """ IntegerSet(ndims, nsymbols, constraints, eqflags; context=context()) @@ -22,11 +23,13 @@ Gets or creates a new integer set in the given context. The set is defined by a list of affine constraints, with the given number of input dimensions and symbols, which are treated as either equalities (eqflags is 1) or inequalities (eqflags is 0). Both `constraints` and `eqflags` need to be arrays of the same length. """ -IntegerSet(ndims, nsymbols, constraints, eqflags; context::Context=context()) = IntegerSet( - API.mlirIntegerSetGet( - context, ndims, nsymbols, length(constraints), constraints, eqflags - ), -) +function IntegerSet(ndims, nsymbols, constraints, eqflags; context::Context=context()) + IntegerSet( + API.mlirIntegerSetGet( + context, ndims, nsymbols, length(constraints), constraints, eqflags + ), + ) +end """ mlirIntegerSetReplaceGet(set, dimReplacements, symbolReplacements, numResultDims, numResultSymbols) @@ -35,15 +38,17 @@ Gets or creates a new integer set in which the values and dimensions of the give `dimReplacements` and `symbolReplacements` are expected to point to at least as many consecutive expressions as the given set has dimensions and symbols, respectively. The new set will have `numResultDims` and `numResultSymbols` dimensions and symbols, respectively. """ -Base.replace(set::IntegerSet, dim_replacements, symbol_replacements) = IntegerSet( - API.mlirIntegerSetReplaceGet( - set, - dim_replacements, - symbol_replacements, - length(dim_replacements), - length(symbol_replacements), - ), -) +function Base.replace(set::IntegerSet, dim_replacements, symbol_replacements) + IntegerSet( + API.mlirIntegerSetReplaceGet( + set, + dim_replacements, + symbol_replacements, + length(dim_replacements), + length(symbol_replacements), + ), + ) +end Base.convert(::Core.Type{API.MlirIntegerSet}, set::IntegerSet) = set.set diff --git a/src/mlir/IR/Module.jl b/src/mlir/IR/Module.jl index 12794b30ba..87950f5e2a 100644 --- a/src/mlir/IR/Module.jl +++ b/src/mlir/IR/Module.jl @@ -23,8 +23,9 @@ Base.convert(::Core.Type{API.MlirModule}, module_::Module) = module_.module_ Parses a module from the string and transfers ownership to the caller. """ -Base.parse(::Core.Type{Module}, module_; context::Context=context()) = +function Base.parse(::Core.Type{Module}, module_; context::Context=context()) Module(API.mlirModuleCreateParse(context, module_)) +end macro mlir_str(code) quote diff --git a/src/mlir/IR/Operation.jl b/src/mlir/IR/Operation.jl index 6f45bbf8ec..c4374b86ad 100644 --- a/src/mlir/IR/Operation.jl +++ b/src/mlir/IR/Operation.jl @@ -65,8 +65,9 @@ block(operation::Operation) = Block(API.mlirOperationGetBlock(operation), false) Gets the operation that owns this operation, returning null if the operation is not owned. """ -parent_op(operation::Operation) = +function parent_op(operation::Operation) Operation(API.mlirOperationGetParentOperation(operation), false) +end """ rmfromparent!(op) @@ -207,8 +208,9 @@ end Removes an attribute by name. Returns false if the attribute was not found and true if removed. """ -rmattr!(operation::Operation, name) = +function rmattr!(operation::Operation, name) API.mlirOperationRemoveAttributeByName(operation, name) +end function lose_ownership!(operation::Operation) @assert operation.owned @@ -271,8 +273,9 @@ end Returns whether the given fully-qualified operation (i.e. 'dialect.operation') is registered with the context. This will return true if the dialect is loaded and the operation is registered within the dialect. """ -is_registered(opname; context::Context=context()) = +function is_registered(opname; context::Context=context()) API.mlirContextIsRegisteredOperation(context, opname) +end function create_operation( name, diff --git a/src/mlir/IR/Pass.jl b/src/mlir/IR/Pass.jl index 94deb66e5d..0017c0e007 100644 --- a/src/mlir/IR/Pass.jl +++ b/src/mlir/IR/Pass.jl @@ -30,8 +30,9 @@ PassManager(; context::Context=context()) = PassManager(API.mlirPassManagerCreat Create a new top-level PassManager anchored on `anchorOp`. """ -PassManager(anchor_op::Operation; context::Context=context()) = +function PassManager(anchor_op::Operation; context::Context=context()) PassManager(API.mlirPassManagerCreateOnOperation(context, anchor_op)) +end Base.convert(::Core.Type{API.MlirPassManager}, pass::PassManager) = pass.pass @@ -96,8 +97,9 @@ end Cast a top-level `PassManager` to a generic `OpPassManager`. """ -OpPassManager(pm::PassManager) = +function OpPassManager(pm::PassManager) OpPassManager(API.mlirPassManagerGetAsOpPassManager(pm), pm) +end """ OpPassManager(passManager, operationName) @@ -105,16 +107,18 @@ OpPassManager(pm::PassManager) = Nest an `OpPassManager` under the top-level PassManager, the nested passmanager will only run on operations matching the provided name. The returned `OpPassManager` will be destroyed when the parent is destroyed. To further nest more `OpPassManager` under the newly returned one, see `mlirOpPassManagerNest` below. """ -OpPassManager(pm::PassManager, opname) = +function OpPassManager(pm::PassManager, opname) OpPassManager(API.mlirPassManagerGetNestedUnder(pm, opname), pm) +end """ OpPassManager(opPassManager, operationName) Nest an `OpPassManager` under the provided `OpPassManager`, the nested passmanager will only run on operations matching the provided name. The returned `OpPassManager` will be destroyed when the parent is destroyed. """ -OpPassManager(opm::OpPassManager, opname) = +function OpPassManager(opm::OpPassManager, opname) OpPassManager(API.mlirOpPassManagerGetNestedUnder(opm, opname), opm.pass) +end Base.convert(::Core.Type{API.MlirOpPassManager}, op_pass::OpPassManager) = op_pass.op_pass diff --git a/src/mlir/IR/Region.jl b/src/mlir/IR/Region.jl index 38d38c8fef..985d2f46dd 100644 --- a/src/mlir/IR/Region.jl +++ b/src/mlir/IR/Region.jl @@ -59,16 +59,18 @@ end Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given region. The reference block must belong to the region. If the reference block is null, prepends the block to the region. """ -insert_after!(region::Region, reference::Block, block::Block) = +function insert_after!(region::Region, reference::Block, block::Block) API.mlirRegionInsertOwnedBlockAfter(region, reference, lose_ownership!(block)) +end """ insert_before!(region, reference, block) Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given region. The reference block must belong to the region. If the reference block is null, appends the block to the region. """ -insert_before!(region::Region, reference::Block, block::Block) = +function insert_before!(region::Region, reference::Block, block::Block) API.mlirRegionInsertOwnedBlockBefore(region, reference, lose_ownership!(block)) +end """ first_block(region) diff --git a/src/mlir/IR/Type.jl b/src/mlir/IR/Type.jl index 3ac576e39c..cd32a8293f 100644 --- a/src/mlir/IR/Type.jl +++ b/src/mlir/IR/Type.jl @@ -14,8 +14,9 @@ Base.convert(::Core.Type{API.MlirType}, type::Type) = type.type Parses a type. The type is owned by the context. """ -Base.parse(::Core.Type{Type}, s; context::Context=context()) = +function Base.parse(::Core.Type{Type}, s; context::Context=context()) Type(API.mlirTypeParseGet(context, s)) +end """ ==(t1, t2) @@ -73,8 +74,9 @@ isindex(type::Type) = API.mlirTypeIsAIndex(type) Creates a 1-bit signless integer type in the context. The type is owned by the context. """ -Type(::Core.Type{Bool}; context::Context=context()) = +function Type(::Core.Type{Bool}; context::Context=context()) Type(API.mlirIntegerTypeGet(context, 1)) +end # Integer types """ @@ -82,24 +84,27 @@ Type(::Core.Type{Bool}; context::Context=context()) = Creates a signless integer type of the given bitwidth in the context. The type is owned by the context. """ -Type(T::Core.Type{<:Integer}; context::Context=context()) = +function Type(T::Core.Type{<:Integer}; context::Context=context()) Type(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) +end """ Type(T::Core.Type{<:Signed}; context=context() Creates a signed integer type of the given bitwidth in the context. The type is owned by the context. """ -Type(T::Core.Type{<:Signed}; context::Context=context()) = +function Type(T::Core.Type{<:Signed}; context::Context=context()) Type(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) +end """ Type(T::Core.Type{<:Unsigned}; context=context() Creates an unsigned integer type of the given bitwidth in the context. The type is owned by the context. """ -Type(T::Core.Type{<:Unsigned}; context::Context=context()) = +function Type(T::Core.Type{<:Unsigned}; context::Context=context()) Type(API.mlirIntegerTypeUnsignedGet(context, sizeof(T) * 8)) +end """ isinteger(type) @@ -626,8 +631,9 @@ end Creates a tuple type that consists of the given list of elemental types. The type is owned by the context. """ -Type(elements::Vector{Type}; context::Context=context()) = +function Type(elements::Vector{Type}; context::Context=context()) Type(API.mlirTupleTypeGet(context, length(elements), elements)) +end function Type(@nospecialize(elements::NTuple{N,Type}); context::Context=context()) where {N} return Type(collect(elements); context) end @@ -712,8 +718,9 @@ end Creates an opaque type in the given context associated with the dialect identified by its namespace. The type contains opaque byte data of the specified length (data need not be null-terminated). """ -OpaqueType(namespace, data; context::Context=context()) = +function OpaqueType(namespace, data; context::Context=context()) Type(API.mlirOpaqueTypeGet(context, namespace, data)) +end """ isopaque(type) diff --git a/src/stdlibs/Base.jl b/src/stdlibs/Base.jl index dd13f195ce..a42686ea90 100644 --- a/src/stdlibs/Base.jl +++ b/src/stdlibs/Base.jl @@ -1,4 +1,6 @@ -@inline Base.vcat(a::Number, b::Union{AnyConcreteRArray,AnyTracedRArray}) = - @allowscalar(vcat(fill!(similar(b, typeof(a), (1, size(b)[2:end]...)), a), b)) -@inline Base.hcat(a::Number, b::Union{AnyConcreteRArray,AnyTracedRArray}) = - @allowscalar(hcat(fill!(similar(b, typeof(a), (size(b)[1:(end - 1)]..., 1)), a), b)) +@inline Base.vcat(a::Number, b::Union{AnyConcreteRArray,AnyTracedRArray}) = @allowscalar( + vcat(fill!(similar(b, typeof(a), (1, size(b)[2:end]...)), a), b) +) +@inline Base.hcat(a::Number, b::Union{AnyConcreteRArray,AnyTracedRArray}) = @allowscalar( + hcat(fill!(similar(b, typeof(a), (size(b)[1:(end - 1)]..., 1)), a), b) +) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 19e94b2054..5f4be7f859 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -51,7 +51,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE")) m, n = size(x) row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2) - indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(comp)) + indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=($(comp))) return Ops.select(indicator, parent(x), zero(parent(x))) end @@ -134,8 +134,10 @@ for (AT, dcomp, ocomp) in ( m, n = size(x) row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2) - data_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(dcomp)) - original_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(ocomp)) + data_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=($(dcomp))) + original_indicator = Ops.compare( + row_idxs, col_idxs; comparison_direction=($(ocomp)) + ) res = Ops.add( Ops.select(data_indicator, tdata, z), Ops.select(original_indicator, x.data, z) ) diff --git a/src/stdlibs/Random.jl b/src/stdlibs/Random.jl index 617f1fac19..6b4e77d671 100644 --- a/src/stdlibs/Random.jl +++ b/src/stdlibs/Random.jl @@ -19,8 +19,9 @@ using ..Reactant: unwrapped_eltype using Random: Random, AbstractRNG -@noinline make_seed(rng::AbstractRNG=Random.RandomDevice()) = - Random.rand!(rng, Vector{UInt64}(undef, 2)) +@noinline make_seed(rng::AbstractRNG=Random.RandomDevice()) = Random.rand!( + rng, Vector{UInt64}(undef, 2) +) @noinline function Random.seed!(rng::TracedRNG, seed::Number) if seed isa TracedRNumber @@ -142,7 +143,9 @@ for randfun in (:rand, :randn, :randexp) end # scalars - @noinline function $(overload_randfun)(rng::TracedRNG, ::Type{T}=Float64) where {T} + @noinline function $(overload_randfun)( + rng::TracedRNG, (::Type{T})=Float64 + ) where {T} A = TracedUtils.promote_to(TracedRArray{T,0}, fill(T(0))) $(overload_randfun!)(rng, A) return TracedRNumber{T}((), A.mlir_data) diff --git a/src/utils.jl b/src/utils.jl index 4b909757b6..7d304a2c0d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -743,8 +743,6 @@ function call_with_reactant_generator( push!(overdubbed_code, Core.ReturnNode(ocres)) push!(overdubbed_codelocs, code_info.codelocs[1]) - #=== set `code_info`/`reflection` fields accordingly ===# - if code_info.method_for_inference_limit_heuristics === nothing code_info.method_for_inference_limit_heuristics = method end diff --git a/test/control_flow.jl b/test/control_flow.jl index dae8aeb8a9..d32c0c8c7e 100644 --- a/test/control_flow.jl +++ b/test/control_flow.jl @@ -460,7 +460,7 @@ function condition_with_structure(x) @trace if sum(y) > 0 z = (; a=y, b=(y .- 1, y)) else - z = (; a=-y, b=(y, y .+ 1)) + z = (; a=(-y), b=(y, y .+ 1)) end return z end diff --git a/test/indexing.jl b/test/indexing.jl index ca6a5ccb01..9dbfb4a95f 100644 --- a/test/indexing.jl +++ b/test/indexing.jl @@ -264,7 +264,7 @@ end f = rand(ComplexF64, N, N) p = rand(ComplexF64, N * N) - I = 1:(N^2) + I = 1:(N ^ 2) out = rand(ComplexF64, M, M) fr = Reactant.to_rarray(f) diff --git a/test/nn/luxlib.jl b/test/nn/luxlib.jl index f1bafff210..b326cbc920 100644 --- a/test/nn/luxlib.jl +++ b/test/nn/luxlib.jl @@ -1,8 +1,9 @@ using LuxLib, Reactant, Enzyme, NNlib @testset "Fused Dense" begin - sumabs2fuseddense(act, weight, x, bias) = - sum(abs2, fused_dense_bias_activation(act, weight, x, bias)) + sumabs2fuseddense(act, weight, x, bias) = sum( + abs2, fused_dense_bias_activation(act, weight, x, bias) + ) function ∇fuseddense(act, weight, x, bias) dw = Enzyme.make_zero(weight)