From 057e6b89ab1b12fd8859f58c60ca29ebee2b628c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 24 Dec 2024 21:46:23 +0530 Subject: [PATCH] fix: handle traced array returns inside objects (#417) * fix: handle traced array returns inside objects * test: add #416 as a test * fix: propagate track_numbers correctly * fix: aliasing and add a test * test: use updated API for the tests * feat: cache new arrays * fix: traced_getfield --- src/Compiler.jl | 120 ++++++++++++++++++++++++++------------ src/Interpreter.jl | 15 ++++- src/TracedUtils.jl | 12 ++++ src/Tracing.jl | 141 ++++++++++++++++++++++++++++++--------------- test/autodiff.jl | 45 +++++++++++++++ test/basic.jl | 28 +++++++++ test/tracing.jl | 8 +-- 7 files changed, 278 insertions(+), 91 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 8fa16cfd3..dae77a067 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -10,6 +10,8 @@ import ..Reactant: ConcreteRNumber, TracedRArray, TracedRNumber, + RArray, + RNumber, OrderedIdDict, make_tracer, TracedToConcrete, @@ -17,9 +19,18 @@ import ..Reactant: TracedType @inline traced_getfield(@nospecialize(obj), field) = Base.getfield(obj, field) -@inline traced_getfield( - @nospecialize(obj::AbstractArray{<:Union{ConcreteRNumber,ConcreteRArray}}), field -) = Base.getindex(obj, field) +@inline function traced_getfield(@nospecialize(obj::AbstractArray{T}), field) where {T} + (isbitstype(T) || obj isa RArray) && return Base.getfield(obj, field) + return Base.getindex(obj, field) +end + +@inline traced_setfield!(@nospecialize(obj), field, val) = Base.setfield!(obj, field, val) +@inline function traced_setfield!( + @nospecialize(obj::AbstractArray{T}), field, val +) where {T} + (isbitstype(T) || obj isa RArray) && return Base.setfield!(obj, field, val) + return Base.setindex!(obj, val, field) +end function create_result(tocopy::T, path, result_stores) where {T} if !isstructtype(typeof(tocopy)) @@ -573,32 +584,32 @@ function codegen_flatten!(linear_args, result_stores) push!(flatten_code, :($usbuf = $flatcode.data)) push!(flatten_code, :($sbuf = XLA.synced_buffer($usbuf))) - # TODO - respaths = ((p for p in arg.paths if p[1] != :args)...,) + # TODO: unused for the time being + # respaths = ((p for p in arg.paths if p[1] == :result || p[1] == :resargs)...,) # resarg = false - for respath in respaths - if respath[1] == :result - flatcode = :result - respath = respath[2:end] - result_stores[respath] = usbuf - resarg = true - else - @assert respath[1] == :resargs - if respath[2] != path[2] - continue - end - # flatcode = :(args[$(respath[2])]) - path = path[3:end] - end - # for p in path - # flatcode = :(traced_getfield($flatcode, $(Meta.quot(p)))) - # end - # resarg = true - # flatcode = :($flatcode.data = $usbuf) - # @show flatcode - # push!(flatten_code, res) - end + # for respath in respaths + # if respath[1] == :result + # flatcode = :result + # respath = respath[2:end] + # result_stores[respath] = usbuf + # resarg = true + # else + # @assert respath[1] == :resargs + # if respath[2] != path[2] + # continue + # end + # # flatcode = :(args[$(respath[2])]) + # path = path[3:end] + # end + # # for p in path + # # flatcode = :(traced_getfield($flatcode, $(Meta.quot(p)))) + # # end + # # resarg = true + # # flatcode = :($flatcode.data = $usbuf) + # # @show flatcode + # # push!(flatten_code, res) + # end # if resarg # push!(resarg_code, :($usbuf = $flatcode.data)) # end @@ -620,11 +631,16 @@ function codegen_unflatten!( concrete_result, result_stores, ) - unflatten_code = Expr[] + cache_dict = gensym("cache_dict") + unflatten_code = Expr[:( + $cache_dict = $(IdDict{ + Union{TracedRArray,TracedRNumber},Union{ConcreteRArray,ConcreteRNumber} + }()) + ),] # mutate the result stores to point to the correct concrete results for (concrete_res_name, result) in zip(concretized_res_names, linear_results) - paths = ((p for p in result.paths if p[1] != :args)...,) + paths = ((p for p in result.paths if p[1] == :result || p[1] == :resargs)...,) for path in paths if path[1] == :result unflatcode = :result @@ -635,15 +651,47 @@ function codegen_unflatten!( @assert path[1] == :resargs unflatcode = :(args[$(path[2])]) path = path[3:end] - end - # unroll path tree - for p in path - unflatcode = :(traced_getfield($unflatcode, $(Meta.quot(p)))) - end - unflatcode = :($unflatcode.data = $concrete_res_name) + for p in path[1:(end - 1)] + unflatcode = :(traced_getfield($unflatcode, $(Meta.quot(p)))) + end - push!(unflatten_code, unflatcode) + if length(path) > 0 + final_val = gensym("final_val") + clocal = gensym("clocal") + unflatcode = quote + $final_val = traced_getfield($unflatcode, $(Meta.quot(path[end]))) + if $final_val isa TracedRArray + $clocal = if haskey($cache_dict, $final_val) + $cache_dict[$final_val] + else + $cache_dict[$final_val] = ConcreteRArray{ + eltype($final_val),ndims($final_val) + }( + $concrete_res_name, size($final_val) + ) + $cache_dict[$final_val] + end + traced_setfield!($unflatcode, $(Meta.quot(path[end])), $clocal) + elseif $final_val isa TracedRNumber + $clocal = if haskey($cache_dict, $final_val) + $cache_dict[$final_val] + else + $cache_dict[$final_val] = ConcreteRNumber{eltype($final_val)}( + $concrete_res_name + ) + $cache_dict[$final_val] + end + traced_setfield!($unflatcode, $(Meta.quot(path[end])), $clocal) + else + traced_setfield!($final_val, :data, $concrete_res_name) + end + end + else + unflatcode = :($unflatcode.data = $concrete_res_name) + end + push!(unflatten_code, unflatcode) + end end end diff --git a/src/Interpreter.jl b/src/Interpreter.jl index f5e4475a2..f75e57cfa 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -235,7 +235,7 @@ function overload_autodiff( primf = f.val primargs = ((v.val for v in args)...,) - fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = Reactant.TracedUtils.make_mlir_fn( + fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = TracedUtils.make_mlir_fn( primf, primargs, (), string(f) * "_autodiff", false ) @@ -302,7 +302,7 @@ function overload_autodiff( cst = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1) push!(ad_inputs, cst) end - else + elseif TracedUtils.has_argidx(a) idx, path = TracedUtils.get_argidx(a) if idx == 1 && fnwrap act = act_from_type(f, reverse, true) @@ -322,6 +322,12 @@ function overload_autodiff( end TracedUtils.push_val!(ad_inputs, args[idx].dval, path[3:end]) end + else + act = act_from_type(Enzyme.Const, reverse, true) + push!(ret_activity, act) + if act != enzyme_out && act != enzyme_outnoneed + continue + end end end @@ -385,7 +391,7 @@ function overload_autodiff( end residx += 1 end - else + elseif TracedUtils.has_argidx(a) idx, path = TracedUtils.get_argidx(a) if idx == 1 && fnwrap TracedUtils.set!( @@ -405,6 +411,9 @@ function overload_autodiff( ) residx += 1 end + else + TracedUtils.set!(a, (), TracedUtils.transpose_val(MLIR.IR.result(res, residx))) + residx += 1 end end diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 159a98a85..665c6732d 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -341,6 +341,18 @@ function get_argidx(x) throw(AssertionError("No path found for $x")) end +function has_argidx(x) + for path in x.paths + if length(path) == 0 + continue + end + if path[1] == :args + return true + end + end + return false +end + function set!(x, path, tostore; emptypath=false) for p in path x = Reactant.Compiler.traced_getfield(x, p) diff --git a/src/Tracing.jl b/src/Tracing.jl index 62bb71f69..5b1f1e72c 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -6,32 +6,34 @@ TracedSetPath = 5 end -for T in ( - DataType, - Module, - Nothing, - Symbol, - AbstractChar, - AbstractFloat, - Integer, - AbstractString, - RArray, - RNumber, -) - @eval function traced_type(::Type{T}, seen, mode) where {T<:$T} +for T in (DataType, Module, Nothing, Symbol, AbstractChar, AbstractString, RArray, RNumber) + @eval function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:$T} return T end end -function traced_type(::Type{C}, seen::ST, mode::Val{Mode}) where {T,C<:Complex{T},ST,Mode} +function traced_type( + ::Type{T}, seen, mode::Val{Mode}, track_numbers +) where {T<:Union{AbstractFloat,Integer},Mode} + if Mode == ArrayToConcrete && any(Base.Fix1(<:, T), track_numbers) + return ConcreteRNumber{T} + end + return T +end + +function traced_type( + ::Type{C}, seen::ST, mode::Val{Mode}, track_numbers::TN +) where {T,C<:Complex{T},ST,Mode,TN} if !(C isa UnionAll) - return Complex{traced_type(T, seen, mode)} + return Complex{traced_type(T, seen, mode, track_numbers)} else - return @invoke traced_type(C::Type{Any}, seen::ST, mode::Val{Mode}) + return @invoke traced_type( + C::Type{Any}, seen::ST, mode::Val{Mode}, track_numbers::TN + ) end end -function traced_type(::Type{T}, seen, mode) where {T<:Function} +function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:Function} # functions are directly returned if sizeof(T) == 0 return T @@ -41,7 +43,7 @@ function traced_type(::Type{T}, seen, mode) where {T<:Function} N = fieldcount(T) changed = false traced_fieldtypes = ntuple(Val(N)) do i - next = traced_type(fieldtype(T, i), seen, mode) + next = traced_type(fieldtype(T, i), seen, mode, track_numbers) changed |= next != fieldtype(T, i) next end @@ -57,31 +59,34 @@ end @inline is_concrete_tuple(x::T2) where {T2} = (x <: Tuple) && !(x === Tuple) && !(x isa UnionAll) -function traced_type(::Type{T}, seen, mode) where {T<:Tuple} +function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:Tuple} if !Base.isconcretetype(T) || !is_concrete_tuple(T) || T isa UnionAll throw(AssertionError("Type $T is not concrete type or concrete tuple")) elseif is_concrete_tuple(T) && any(T2 isa Core.TypeofVararg for T2 in T.parameters) # Tuple{((T2 isa Core.TypeofVararg ? Any : T2) for T2 in T.parameters)...} throw(AssertionError("Type tuple of vararg $T is not supported")) end - TT = [traced_type(T.parameters[i], seen, mode) for i in 1:length(T.parameters)] + TT = [ + traced_type(T.parameters[i], seen, mode, track_numbers) for + i in 1:length(T.parameters) + ] return Tuple{TT...} end -function traced_type(::Type{T}, seen, mode) where {N,V,T<:NamedTuple{N,V}} - return NamedTuple{N,traced_type(V, seen, mode)} +function traced_type(::Type{T}, seen, mode, track_numbers) where {N,V,T<:NamedTuple{N,V}} + return NamedTuple{N,traced_type(V, seen, mode, track_numbers)} end -function traced_type(::Type{T}, seen, mode) where {K,V,T<:AbstractDict{K,V}} +function traced_type(::Type{T}, seen, mode, track_numbers) where {K,V,T<:AbstractDict{K,V}} dictty = T.name.wrapper - return dictty{K,traced_type(V, seen, mode)} + return dictty{K,traced_type(V, seen, mode, track_numbers)} end @inline getmap(::Val{T}) where {T} = nothing @inline getmap(::Val{T}, a, b, args...) where {T} = getmap(Val(T), args...) @inline getmap(::Val{T}, ::Val{T}, ::Val{T2}, args...) where {T,T2} = T2 -function traced_type(::Type{T}, seen, mode) where {T} +function traced_type(::Type{T}, seen, mode, track_numbers) where {T} if T === Any return T end @@ -110,7 +115,10 @@ function traced_type(::Type{T}, seen, mode) where {T} end if T isa Union - return Union{traced_type(T.a, seen, mode),traced_type(T.b, seen, mode)} + return Union{ + traced_type(T.a, seen, mode, track_numbers), + traced_type(T.b, seen, mode, track_numbers), + } end # if abstract it must be by reference @@ -133,7 +141,7 @@ function traced_type(::Type{T}, seen, mode) where {T} subTys = Type[] for f in 1:fieldcount(T) subT = fieldtype(T, f) - subTT = traced_type(subT, seen2, mode) + subTT = traced_type(subT, seen2, mode, track_numbers) changed |= subT != subTT push!(subTys, subTT) end @@ -145,7 +153,7 @@ function traced_type(::Type{T}, seen, mode) where {T} subParms = [] for SST in T.parameters if SST isa Type - TrT = traced_type(SST, seen, mode) + TrT = traced_type(SST, seen, mode, track_numbers) push!(subParms, TrT) else push!(subParms, SST) @@ -163,7 +171,7 @@ function traced_type(::Type{T}, seen, mode) where {T} for f in 1:fieldcount(T) subT = fieldtype(T, f) subT2 = fieldtype(TT2, f) - subTT = traced_type(subT, seen3, mode) + subTT = traced_type(subT, seen3, mode, track_numbers) if subT2 != subTT legal = false break @@ -178,7 +186,9 @@ function traced_type(::Type{T}, seen, mode) where {T} throw(NoFieldMatchError(T, TT2)) end -function traced_type(::Type{<:ConcreteRNumber{T}}, seen, ::Val{mode}) where {T,mode} +function traced_type( + ::Type{<:ConcreteRNumber{T}}, seen, ::Val{mode}, track_numbers +) where {T,mode} if mode == ConcreteToTraced return TracedRNumber{T} elseif mode == TracedToConcrete @@ -188,7 +198,9 @@ function traced_type(::Type{<:ConcreteRNumber{T}}, seen, ::Val{mode}) where {T,m end end -function traced_type(::Type{T}, seen, ::Val{mode}) where {T<:ConcreteRArray,mode} +function traced_type( + ::Type{T}, seen, ::Val{mode}, track_numbers +) where {T<:ConcreteRArray,mode} if mode == ConcreteToTraced @inline base_typet(TV::TT) where {TT<:UnionAll} = UnionAll(TV.var, base_typet(TV.body)) @@ -201,7 +213,9 @@ function traced_type(::Type{T}, seen, ::Val{mode}) where {T<:ConcreteRArray,mode end end -function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedType,mode} +function traced_type( + ::Type{T}, seen::ST, ::Val{mode}, track_numbers +) where {ST,T<:TracedType,mode} T <: MissingTracedValue && error("TODO") if mode == ConcreteToTraced throw("TracedRArray $T cannot be traced") @@ -218,26 +232,28 @@ function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedType,m end end -function traced_type(::Type{T}, seen, mode) where {T<:XLAArray} +function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:XLAArray} throw("XLA $T array cannot be traced") end -function traced_type(::Type{A}, seen::ST, ::Val{mode}) where {T,N,A<:Array{T,N},ST,mode} +function traced_type( + ::Type{A}, seen::ST, ::Val{mode}, track_numbers +) where {T,N,A<:Array{T,N},ST,mode} if mode == ArrayToConcrete && T <: ReactantPrimitive return ConcreteRArray{T,N} else - return Array{traced_type(T, seen, Val(mode)),N} + return Array{traced_type(T, seen, Val(mode), track_numbers),N} end end for P in (Ptr, Core.LLVMPtr, Base.RefValue) - @eval function traced_type(::Type{P}, seen, mode) where {T,P<:$P{T}} - return $P{traced_type(T, seen, mode)} + @eval function traced_type(::Type{P}, seen, mode, track_numbers) where {T,P<:$P{T}} + return $P{traced_type(T, seen, mode, track_numbers)} end end -function traced_type(::Type{Val{T}}, seen, mode) where {T} - if traced_type(typeof(T), seen, mode) == typeof(T) +function traced_type(::Type{Val{T}}, seen, mode, track_numbers) where {T} + if traced_type(typeof(T), seen, mode, track_numbers) == typeof(T) return Val{T} end throw("Val type $(Val{T}) cannot be traced") @@ -274,12 +290,13 @@ function make_tracer( mode; toscalar=false, tobatch=nothing, + track_numbers=(), kwargs..., ) where {RT} if haskey(seen, prev) return seen[prev] end - TT = traced_type(RT, (), Val(mode)) + TT = traced_type(RT, (), Val(mode), track_numbers) @assert !Base.isabstracttype(RT) @assert Base.isconcretetype(RT) nf = fieldcount(RT) @@ -295,7 +312,16 @@ function make_tracer( for i in 1:nf if isdefined(prev, i) xi = Base.getfield(prev, i) - xi2 = make_tracer(seen, xi, append_path(path, i), mode; toscalar, tobatch) + xi2 = make_tracer( + seen, + xi, + append_path(path, i), + mode; + toscalar, + tobatch, + track_numbers, + kwargs..., + ) if xi !== xi2 changed = true end @@ -318,7 +344,16 @@ function make_tracer( for i in 1:nf if isdefined(prev, i) xi = Base.getfield(prev, i) - xi2 = make_tracer(seen, xi, append_path(path, i), mode; toscalar, tobatch) + xi2 = make_tracer( + seen, + xi, + append_path(path, i), + mode; + toscalar, + tobatch, + track_numbers, + kwargs..., + ) if xi !== xi2 changed = true end @@ -543,7 +578,7 @@ function make_tracer( end function make_tracer( - seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs... + seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs... ) where {RT<:Array} if haskey(seen, prev) return seen[prev] @@ -551,14 +586,14 @@ function make_tracer( if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive return seen[prev] = ConcreteRArray(prev) end - TT = traced_type(eltype(RT), (), Val(mode)) + TT = traced_type(eltype(RT), (), Val(mode), track_numbers) newa = Array{TT,ndims(RT)}(undef, size(prev)) seen[prev] = newa same = true for I in eachindex(prev) if isassigned(prev, I) pv = prev[I] - nv = make_tracer(seen, pv, append_path(path, I), mode; kwargs...) + nv = make_tracer(seen, pv, append_path(path, I), mode; track_numbers, kwargs...) if pv !== nv same = false end @@ -584,12 +619,22 @@ function make_tracer( end function make_tracer( - seen, @nospecialize(prev::NamedTuple{A,RT}), @nospecialize(path), mode; kwargs... + seen, + @nospecialize(prev::NamedTuple{A,RT}), + @nospecialize(path), + mode; + track_numbers=(), + kwargs..., ) where {A,RT} - return NamedTuple{A,traced_type(RT, (), Val(mode))}(( + return NamedTuple{A,traced_type(RT, (), Val(mode), track_numbers)}(( ( make_tracer( - seen, Base.getfield(prev, i), append_path(path, i), mode; kwargs... + seen, + Base.getfield(prev, i), + append_path(path, i), + mode; + track_numbers, + kwargs..., ) for i in 1:length(A) )..., )) diff --git a/test/autodiff.jl b/test/autodiff.jl index 842050413..044799bcb 100644 --- a/test/autodiff.jl +++ b/test/autodiff.jl @@ -75,3 +75,48 @@ end @test typeof(res) == Tuple{Enzyme.TupleArray{ConcreteRNumber{Float64},(2, 2),4,2}} @test res[1] ≈ ones(2, 2) end + +mutable struct StateReturn + st::Any +end + +mutable struct StateReturn1 + st1::Any + st2::Any +end + +function cached_return(x, stret::StateReturn) + loss = sum(x) + stret.st = x .+ 1 + return loss +end + +function cached_return(x, stret::StateReturn1) + loss = sum(x) + tmp = x .+ 1 + stret.st1 = tmp + stret.st2 = tmp + return loss +end + +@testset "Cached Return: Issue #416" begin + x = rand(10) + x_ra = Reactant.to_rarray(x) + + stret = StateReturn(nothing) + ret = @jit Enzyme.gradient(Reverse, cached_return, x_ra, Const(stret)) + + @test @allowscalar all(isone, ret[1]) + @test stret.st isa ConcreteRArray + @test stret.st ≈ x .+ 1 + + stret = StateReturn1(nothing, nothing) + ret = @jit Enzyme.gradient(Reverse, cached_return, x_ra, Const(stret)) + + @test @allowscalar all(isone, ret[1]) + @test stret.st1 isa ConcreteRArray + @test stret.st1 ≈ x .+ 1 + @test stret.st2 isa ConcreteRArray + @test stret.st2 ≈ x .+ 1 + @test stret.st1 === stret.st2 +end diff --git a/test/basic.jl b/test/basic.jl index 31fd841b4..edb0c0e35 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -676,3 +676,31 @@ end fill!(z, 1.0) @test all(==(1.0), Array(z)) end + +@testset "Preserve Aliasing" begin + x = Reactant.to_rarray([3]) + T = Any[nothing] + + function ip(m, T) + @allowscalar m[1] = 2 + T[1] = m + return m + end + + res = @jit ip(x, T) + @test @allowscalar res[1] == 2 + @test @allowscalar x[1] == 2 + @test @allowscalar T[1][1] == 2 + + ptr_x = Base.unsafe_convert( + Ptr{Float64}, Reactant.XLA.UnsafeBufferPointer(x.data.buffer) + ) + ptr_res = Base.unsafe_convert( + Ptr{Float64}, Reactant.XLA.UnsafeBufferPointer(res.data.buffer) + ) + ptr_T1 = Base.unsafe_convert( + Ptr{Float64}, Reactant.XLA.UnsafeBufferPointer(T[1].data.buffer) + ) + + @test ptr_x == ptr_res == ptr_T1 +end diff --git a/test/tracing.jl b/test/tracing.jl index f817dbfdc..3b73212dc 100644 --- a/test/tracing.jl +++ b/test/tracing.jl @@ -90,7 +90,7 @@ using Test (Val{:x}, Val{:x}), ] tracedty = traced_type( - origty, Reactant.OrderedIdDict(), Val(ConcreteToTraced) + origty, Reactant.OrderedIdDict(), Val(ConcreteToTraced), () ) @test tracedty == targetty end @@ -102,13 +102,13 @@ using Test TracedRArray{Float64,3}, ] @test_throws Union{ErrorException,String} traced_type( - type, Reactant.OrderedIdDict(), Val(ConcreteToTraced) + type, Reactant.OrderedIdDict(), Val(ConcreteToTraced), () ) end end @testset "traced_type exceptions" begin @test_throws TracedTypeError Reactant.traced_type( - Real, Reactant.OrderedIdDict(), Val(Reactant.ArrayToConcrete) + Real, Reactant.OrderedIdDict(), Val(Reactant.ArrayToConcrete), () ) struct Node @@ -116,7 +116,7 @@ using Test y::Union{Nothing,Node} end @test_throws NoFieldMatchError traced_type( - Node, Reactant.OrderedIdDict(), Val(ArrayToConcrete) + Node, Reactant.OrderedIdDict(), Val(ArrayToConcrete), () ) end end