diff --git a/Project.toml b/Project.toml index 2e535e5..a506c37 100644 --- a/Project.toml +++ b/Project.toml @@ -15,9 +15,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -CUDA = "1.2, 2.3" +CUDA = "1.2, 2" ChainRulesCore = "0.9.5" Distributions = "0.23.2" Espresso = "0.6.0" -IRTools = "0.4.0" +IRTools = "0.4" julia = "1.4" diff --git a/src/alloc.jl b/src/alloc.jl deleted file mode 100644 index f2c8947..0000000 --- a/src/alloc.jl +++ /dev/null @@ -1,65 +0,0 @@ -## Memory allocation strategies - -abstract type AbstractMemoryPool end - - -mutable struct SimplePool <: AbstractMemoryPool -end - - -alloc(mp::SimplePool, T, sz) = T(undef, sz...) -free(mp::SimplePool, a) = () - - -mutable struct CachingPool - cache::Dict{Any, Any} # (T, sz) => [buffer1, buffer2, ...] -end -CachingPool() = CachingPool(Dict()) - - -function alloc(mp::CachingPool, T, sz) - key = (T, sz) - if haskey(mp.cache, key) - buffers = mp.cache[key] - arr = pop!(buffers) - isempty(buffers) && delete!(mp.cache, key) - return arr - else - try - return T(undef, sz) - catch e - if e isa OutOfMemoryError - # worst case - release all cached memory, run GC and repeat attempt - mp.cache = Dict() - GC.gc() - return T(undef, sz) - end - end - end -end - -function free(mp::CachingPool, arr) - key = (typeof(arr), size(arr)) - if !haskey(mp.cache, key) - mp.cache[key] = Any[] - end - push!(mp.cache[key], arr) -end - - - - -using LinearAlgebra - -function usage() - a = rand(5, 4) |> cu - b = rand(4, 10) |> cu - - mp = CachingPool() - # alloc x - x = alloc(mp, CuArray{Float32, 2}, (5, 10)) - mul!(x, a, b) - # free x when it's not needed anymore - free(mp, x) - x = nothing -end diff --git a/src/core.jl b/src/core.jl index 2246c5f..9c7672a 100644 --- a/src/core.jl +++ b/src/core.jl @@ -19,13 +19,14 @@ include("compile.jl") include("update.jl") include("transform.jl") include("cuda.jl") +include("gradcheck.jl") const BEST_AVAILABLE_DEVICE = Ref{AbstractDevice}(CPU()) if CUDA.functional() try - BEST_AVAILABLE_DEVICE[] = GPU(1) + BEST_AVAILABLE_DEVICE[] = GPU(1) catch ex # something is wrong with the user's set-up (or there's a bug in CuArrays) @warn "CUDA is installed, but not working properly" exception=(ex,catch_backtrace()) @@ -33,5 +34,4 @@ if CUDA.functional() end end - best_available_device() = BEST_AVAILABLE_DEVICE[] diff --git a/src/cuda.jl b/src/cuda.jl index 5942489..399b9b7 100644 --- a/src/cuda.jl +++ b/src/cuda.jl @@ -1,8 +1,3 @@ -# import CUDAnative -# using CuArrays - - -# CuArrays.cufunc(::typeof(^)) = CUDAnative.pow @diffrule CUDA.exp(u::Real) u CUDA.exp(u) * dy @diffrule CUDA.pow(u::Real, v::Real) u (v * CUDA.pow(u, (v-1)) * dy) @@ -44,32 +39,4 @@ function to_device(device::GPU, x) fld_vals = [to_device(device, getfield(x, fld)) for fld in flds] return T(fld_vals...) end -end - - -# function cuarray_compatible_tform(tape::Tape) -# new_tape = similar(tape) -# changed = false -# for op in tape -# if op isa Call && haskey(CUDANATIVE_OPS, op.fn) -# changed = true -# push!(new_tape, copy_with(op, fn=CUDANATIVE_OPS[op.fn])) -# else -# push!(new_tape, op) -# end -# end -# return new_tape, changed -# end - - -# """ -# Transform function to CuArrays compatible. -# """ -# function cuda_compatible(f, args) -# cf = CuArrays.cufunc(f) -# if f === cf -# return cf -# else -# return transform(cuarray_compatible_tform, f, args) -# end -# end +end \ No newline at end of file diff --git a/src/devices.jl b/src/devices.jl index 876d62e..f006edb 100644 --- a/src/devices.jl +++ b/src/devices.jl @@ -43,5 +43,9 @@ to_device(device::CPU, f::Function, args) = f (device::CPU)(x) = to_device(device, x) (device::GPU)(x) = to_device(device, x) +to_cpu(A) = A +to_cpu(A::CuArray) = convert(Array, A) +to_cuda(A) = cu(A) +to_cuda(A::CuArray) = A to_same_device(A, example) = device_of(example)(A) diff --git a/src/diffrules/basic.jl b/src/diffrules/basic.jl index 322b460..3038606 100644 --- a/src/diffrules/basic.jl +++ b/src/diffrules/basic.jl @@ -131,12 +131,6 @@ @diffrule __getfield__(_s::Tuple, _f) _s ∇getfield(dy, _s, _f) @nodiff __getfield__(_s, _f::Tuple) _f -# @nodiff __new__(t, u) t -# @nodiff __new__(t, u, v) t -# @nodiff __new__(t, u, v, w) t -# @nodiff __new__(t, u, v) t -# @diffrule __new__(t, u, v) u ∇__new__(dy, t, 1) -# @diffrule __new__(t, u, v) v ∇__new__(dy, t, 2) const LONG_VAR_NAME_LIST = (:x, :u, :v, :w, :_a, :_b, :_c, :_d, :_e, :_f, :_g) for n=1:length(LONG_VAR_NAME_LIST) @@ -149,6 +143,23 @@ for n=1:length(LONG_VAR_NAME_LIST) end end +# Base.iterate + +# here we explicitely stop propagation in iteration +# over ranges (e.g for i=1:3 ... end) +@nodiff Base.iterate(x::UnitRange) x +@nodiff Base.iterate(x::UnitRange, i::Int) x +@nodiff Base.iterate(x::UnitRange, i::Int) i + +@diffrule Base.iterate(t::Tuple) t ∇getfield(getindex(dy, 1), t, 1) +@diffrule Base.iterate(t::Tuple, i::Int) t ∇getfield(getindex(dy, 1), t, i) +@nodiff Base.iterate(t::Tuple, i::Int) i + +@diffrule Base.iterate(x::AbstractArray) x ungetindex(x, dy, 1) +@diffrule Base.iterate(x::AbstractArray, i::Int) x ungetindex(x, dy, i) +@nodiff Base.iterate(x::AbstractArray, i::Int) i + + # Base.indexed_iterate (tuple unpacking) @diffrule Base.indexed_iterate(t::Tuple, i::Int) t ∇getfield(getindex(dy, 1), t, i) diff --git a/src/diffrules/diffrules.jl b/src/diffrules/diffrules.jl index 755e2b2..3006e21 100644 --- a/src/diffrules/diffrules.jl +++ b/src/diffrules/diffrules.jl @@ -11,7 +11,7 @@ const DIFF_PHS = Set([:x, :u, :v, :w, :t, :i, :j, :k,]) function reset_rules!() empty!(DIFF_RULES) empty!(NO_DIFF_RULES) - empty!(CONSTRUCTORS) + empty!(CONSTRUCTORS) end @@ -142,7 +142,7 @@ end Define a type constructor that should not be traced, but instead recorded to the tape as is. Here's an example: - @ctor MvNormal(μ, Σ) + @ctor MvNormal(μ, Σ) This should be read as: @@ -158,7 +158,7 @@ Yota will _completely bypass_ internals of the constructor and jump directly to 1st variable passed to MvNormal(). Note that if you don't want to bypass the constructor (which you normally shouldn't do), -you can rely on Yota handling it automatically. +you can rely on Yota handling it automatically. """ macro ctor(ex) @@ -256,7 +256,7 @@ function deriv_exprs(ex, dep_types, idx::Int) result = Tuple[] # list of tuples (rewritten_expr, field_name | nothing) for rule in DIFF_RULES m, fldname = match_rule(rule, ex, dep_types, idx) - if m != nothing + if m !== nothing pat, rpat = m rex = rewrite_with_keywords(ex, pat, rpat) push!(result, (rex, fldname)) @@ -270,9 +270,9 @@ function deriv_exprs(ex, dep_types, idx::Int) # record it to the tape as constant; instead we must rewrite forward pass as well # replacing all f(args...) with _, pb = rrule(f, args...) # which breaks significant part of codebase - # @assert Meta.isexpr(ex, :call) + # @assert Meta.isexpr(ex, :call) # _, df = rrule(ex.args[1], 2.0) - # dex = Expr(:call, df, ) + # dex = Expr(:call, df, ) error("Can't find differentiation rule for $ex at $idx " * "with types $dep_types)") end @@ -331,7 +331,7 @@ function dont_diff(tape::Tape, op::AbstractOp, idx::Int) dep_types = [eltype(tape[arg].typ) for arg in op.args[2:end]] idx_ = idx - 1 else - ex = to_expr(tape, op) + ex = to_expr(tape, op) dep_types = [tape[arg].typ for arg in op.args] idx_ = idx end diff --git a/src/grad.jl b/src/grad.jl index 24ed0c9..8c76977 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -2,31 +2,6 @@ # GRAD RESULT # ######################################################################## -# function field_paths(tape::Tape) -# paths = Dict() -# for op in reverse(tape.ops) -# _op = op -# path = [] -# while _op isa Call && _op.fn in (Base.getproperty, -# Base.getfield, -# __getfield__) -# field_name = tape[_op.args[2]].val -# push!(path, field_name) -# _op_id = _op.args[1] -# _op = tape[_op_id] -# end -# if !isempty(path) -# struct_id = _op.id -# if !haskey(paths, struct_id) -# paths[struct_id] = Dict() -# end -# paths[struct_id][(reverse(path)...,)] = op.id -# end -# end -# return paths -# end - - struct GradResult tape::Tape gvars::Vector{Any} # gradient vars @@ -89,14 +64,14 @@ function deriv!(tape::Tape, op::AbstractOp, i::Int, dy::AbstractOp) st = Dict(Symbol("%$i") => i for i in op.args) st[:dy] = dy.id st[:y] = op.id - if dex_fldnames[1][2] == nothing + if dex_fldnames[1][2] === nothing # not a derivative of a field - take only the 1st match dex_fldnames = dex_fldnames[1:1] end op_deriv_attrs = Tuple[] for (dex, fldname) in dex_fldnames ret_id = record_expr!(tape, dex; st=st) - derivative_of = (fldname == nothing ? tape[op.args[i]] : + derivative_of = (fldname === nothing ? tape[op.args[i]] : field_var_from_ctor_op(tape, tape[op.args[i]], fldname)) push!(op_deriv_attrs, (tape[ret_id], derivative_of)) end diff --git a/test/gradcheck.jl b/src/gradcheck.jl similarity index 81% rename from test/gradcheck.jl rename to src/gradcheck.jl index ec57598..ff94d51 100644 --- a/test/gradcheck.jl +++ b/src/gradcheck.jl @@ -16,14 +16,6 @@ function ngradient(f, xs::AbstractArray...) end -# function gradcheck(f, xs...) -# n_grads = ngradient(f, xs...) -# y_grads = Yota._grad(f, xs...)[2] |> collect -# all(isapprox.(n_grads, y_grads, rtol = 1e-5, atol = 1e-5)) -# end - - - function ngradient2(f, xs, n) x = xs[n] Δ = zero(x) @@ -42,7 +34,7 @@ end function gradcheck2(f, xs...; var_idxs=1:length(xs)) - y_grads = Yota._grad(f, xs...)[2] |> collect + y_grads = _grad(f, xs...)[2] |> collect results = [] for n in var_idxs n_grad = ngradient2(f, xs, n) diff --git a/src/helpers.jl b/src/helpers.jl index 49b482d..5690de0 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -17,8 +17,32 @@ function ∇__new__(dy, T, idx) end +# TODO (when https://github.com/FluxML/NNlib.jl/pull/296 is done): +# 1. uncomment this implementation +# 2. remove the next 2 functions +# +# function ungetindex!(dx::AbstractArray, ::AbstractArray, dy, I...) +# if dy isa Number +# dy = to_same_device([dy], dx) +# end +# I = Any[i for i in I] +# for (d, i) in enumerate(I) +# if i == (:) +# I[d] = 1:size(dx, d) +# end +# if i isa Number +# I[d] = [i] +# end +# I[d] = to_cpu(I[d]) +# end +# # cartesian product of all concerned indices +# idx = collect(Iterators.product(I...)) +# idx = to_same_device(idx, dx) +# return scatter!(+, dx, dy, idx) +# end + function ungetindex!(dx::AbstractArray, ::AbstractArray, dy, I...) - return Scatter.scatter_add2!(dx, dy, I...) + return Scatter.scatter_add2!(dx, dy, I...) end diff --git a/src/tape.jl b/src/tape.jl index 4f542b8..9abda87 100644 --- a/src/tape.jl +++ b/src/tape.jl @@ -49,8 +49,6 @@ mutable struct Call <: AbstractOp args::Vector{Int} end -# Call(id::Int, val::Any, fn::Union{Function, Type}, args::Vector{Int}) = -# Call(id, val, fn, args) Base.getproperty(op::Input, f::Call) = f == :typ ? typeof(op.val) : getfield(op, f) @@ -107,9 +105,6 @@ function Base.show(io::IO, tape::Tape) end Base.getindex(tape::Tape, i...) = getindex(tape.ops, i...) -# Base.getindex(tape::Tape, i::String...) = -# getindex(tape.ops, [parse(Int, s[2:end]) for s in i]...) -# Base.getindex(tape::Tape, i::Symbol...) = getindex(tape, map(string, i)...) Base.setindex!(tape::Tape, op::AbstractOp, i::Integer) = (tape.ops[i] = op) Base.lastindex(tape::Tape) = lastindex(tape.ops) Base.length(tape::Tape) = length(tape.ops) @@ -194,8 +189,6 @@ function record_expr!(tape::Tape, x; st, bcast=false) end - - ######################################################################## # TRANSFORMATIONS # ######################################################################## @@ -284,7 +277,6 @@ end # end - function squash_assigned(tape::Tape) new_tape = copy_with(tape, ops=AbstractOp[]) st = Dict{Int, Int}() # substitution table for op indices @@ -328,40 +320,40 @@ end -""" -Unwind iterate() sequences into plain __getfield__ expressions. -unwind_iterate() doesn't remove unused elements for performance reasons, -so remove_unused() should be called after it. -""" -function unwind_iterate(tape::Tape) - tape = copy_with(tape) - for op in tape - if (op isa Call && op.fn in (getfield, __getfield__) - && tape[op.args[1]] isa Call && tape[op.args[1]].fn == Base.iterate - && tape[op.args[2]] isa Constant && tape[op.args[2]].val == 1) - iterate_op = tape[op.args[1]] - iterable_op = tape[iterate_op.args[1]] - idx = length(iterate_op.args) > 1 ? tape[iterate_op.args[2]].val : 1 - if iterable_op.val isa Tuple || iterable_op.val isa Vector || iterable_op.val isa UnitRange - # 1. Replace iterable op with index in the original iterable - tape[iterate_op.id] = Constant(iterate_op.id, idx) - # 2. Replace __getfield__ on iterator with __getfield__ or getindex on original iterable - idx_id = iterate_op.id - obj_id = iterable_op.id - # TODO: in which other cases getindex is better than __getfield__? - get_op = iterable_op.val isa UnitRange ? getindex : __getfield__ - tape[op.id] = Call(op.id, op.val, get_op, [obj_id, idx_id]) - end - end - end - return tape -end +# """ +# Unwind iterate() sequences into plain __getfield__ expressions. +# unwind_iterate() doesn't remove unused elements for performance reasons, +# so remove_unused() should be called after it. +# """ +# function unwind_iterate(tape::Tape) +# tape = copy_with(tape) +# for op in tape +# if (op isa Call && op.fn in (getfield, __getfield__) +# && tape[op.args[1]] isa Call && tape[op.args[1]].fn == Base.iterate +# && tape[op.args[2]] isa Constant && tape[op.args[2]].val == 1) +# iterate_op = tape[op.args[1]] +# iterable_op = tape[iterate_op.args[1]] +# idx = length(iterate_op.args) > 1 ? tape[iterate_op.args[2]].val : 1 +# if iterable_op.val isa Tuple || iterable_op.val isa Vector || iterable_op.val isa UnitRange +# # 1. Replace iterable op with index in the original iterable +# tape[iterate_op.id] = Constant(iterate_op.id, idx) +# # 2. Replace __getfield__ on iterator with __getfield__ or getindex on original iterable +# idx_id = iterate_op.id +# obj_id = iterable_op.id +# # TODO: in which other cases getindex is better than __getfield__? +# get_op = iterable_op.val isa UnitRange ? getindex : __getfield__ +# tape[op.id] = Call(op.id, op.val, get_op, [obj_id, idx_id]) +# end +# end +# end +# return tape +# end function simplify(tape::Tape) tape = recover_broadcast(tape) tape = squash_assigned(tape) - tape = unwind_iterate(tape) + # tape = unwind_iterate(tape) tape = eliminate_common(tape) tape = remove_unused(tape) return tape @@ -373,8 +365,8 @@ end ######################################################################## -exec!(tape::Tape, op::Input) = () -exec!(tape::Tape, op::Constant) = () +exec!(::Tape, ::Input) = () +exec!(::Tape, ::Constant) = () exec!(tape::Tape, op::Assign) = (op.val = tape[op.src_id].val) exec!(tape::Tape, op::Call) = (op.val = op.fn([tape[id].val for id in op.args]...)) @@ -384,7 +376,7 @@ function play!(tape::Tape, args...; use_compiled=true, debug=false) @assert(tape[i] isa Input, "More arguments than the original function had") tape[i].val = val end - if use_compiled && tape.compiled != nothing + if use_compiled && tape.compiled !== nothing Base.invokelatest(tape.compiled) else for op in tape diff --git a/src/tapeutils.jl b/src/tapeutils.jl index 7df9bf4..3021dbf 100644 --- a/src/tapeutils.jl +++ b/src/tapeutils.jl @@ -216,7 +216,7 @@ function field_var_from_ctor_op(tape::Tape, ctor_op::Call, fldname::Symbol) for ctor_def in CONSTRUCTORS if ctor_typ == ctor_def[1] fld_idx_in_ctor_def = findfirst(isequal(fldname), ctor_def[2:end]) - if fld_idx_in_ctor_def != nothing + if fld_idx_in_ctor_def !== nothing # tape ID of corresponding argument to the constructor arg_id = ctor_op.args[fld_idx_in_ctor_def] return tape[arg_id] @@ -227,9 +227,6 @@ function field_var_from_ctor_op(tape::Tape, ctor_op::Call, fldname::Symbol) end -# TODO: make unified API for 2 versions of this function -# 1st - for __new__ constuctor -# 2nd - for primitive constructor created using @ctor macro function field_var_from_ctor_op(tape::Tape, ctor::Call, getprop_op::Call) @assert ctor.fn == __new__ @assert getprop_op.fn == Base.getproperty @@ -241,54 +238,54 @@ function field_var_from_ctor_op(tape::Tape, ctor::Call, getprop_op::Call) end -""" -Given a tape and getproperty() operation, try to find a variable -that was used to create that field -""" -function find_field_source_var(tape::Tape, getprop_op::Call) - parent = tape[getprop_op.args[1]] - if parent isa Call && parent.fn == __new__ - # found constructor for this field, return variable that made up getprop_op's field - return field_var_from_ctor_op(tape, parent, getprop_op) - elseif parent isa Call && parent.fn == Base.getproperty - # nested getproperty, find constructor for the current struct recursively - ctor = find_field_source_var(tape, parent) - if ctor != nothing - return field_var_from_ctor_op(tape, ctor, getprop_op) - else - return nothing - end - else - # can't find source field - give up and return nothing - return nothing - end -end - - -""" -Given a tape and getfield() operation, try to find a variable -that was used to create that field. -NOTE: This works only with getfield() on tuples -""" -function find_tuple_field_source_var(tape::Tape, getf_op::Call) - getf_base_op = tape[getf_op.args[1]] - if getf_base_op.fn == Base.indexed_iterate - tuple_op = tape[getf_base_op.args[1]] - # tuple_idx = tape[ind_it_op.args[2]].val - tuple_idx = tape[getf_base_op.args[2]].val - elseif getf_base_op.fn == __tuple__ - tuple_op = getf_base_op - tuple_idx = tape[getf_op.args[2]].val - elseif getf_base_op.fn == namedtuple - tuple_op = tape[getf_base_op.args[2]] - tuple_fld = tape[getf_op.args[2]].val - flds = fieldnames(typeof(getf_base_op.val)) - tuple_idx = findfirst(x -> x == tuple_fld, flds) - # should make recursive? - else - throw(AssertionError("Unexpected base op for __getfield__: $(getf_base_op.fn)")) - end - # @assert tuple_op isa Call && tuple_op.fn == __tuple__ - src_var = tape[tuple_op.args[tuple_idx]] - return src_var -end +# """ +# Given a tape and getproperty() operation, try to find a variable +# that was used to create that field +# """ +# function find_field_source_var(tape::Tape, getprop_op::Call) +# parent = tape[getprop_op.args[1]] +# if parent isa Call && parent.fn == __new__ +# # found constructor for this field, return variable that made up getprop_op's field +# return field_var_from_ctor_op(tape, parent, getprop_op) +# elseif parent isa Call && parent.fn == Base.getproperty +# # nested getproperty, find constructor for the current struct recursively +# ctor = find_field_source_var(tape, parent) +# if ctor != nothing +# return field_var_from_ctor_op(tape, ctor, getprop_op) +# else +# return nothing +# end +# else +# # can't find source field - give up and return nothing +# return nothing +# end +# end + + +# """ +# Given a tape and getfield() operation, try to find a variable +# that was used to create that field. +# NOTE: This works only with getfield() on tuples +# """ +# function find_tuple_field_source_var(tape::Tape, getf_op::Call) +# getf_base_op = tape[getf_op.args[1]] +# if getf_base_op.fn == Base.indexed_iterate +# tuple_op = tape[getf_base_op.args[1]] +# # tuple_idx = tape[ind_it_op.args[2]].val +# tuple_idx = tape[getf_base_op.args[2]].val +# elseif getf_base_op.fn == __tuple__ +# tuple_op = getf_base_op +# tuple_idx = tape[getf_op.args[2]].val +# elseif getf_base_op.fn == namedtuple +# tuple_op = tape[getf_base_op.args[2]] +# tuple_fld = tape[getf_op.args[2]].val +# flds = fieldnames(typeof(getf_base_op.val)) +# tuple_idx = findfirst(x -> x == tuple_fld, flds) +# # should make recursive? +# else +# throw(AssertionError("Unexpected base op for __getfield__: $(getf_base_op.fn)")) +# end +# # @assert tuple_op isa Call && tuple_op.fn == __tuple__ +# src_var = tape[tuple_op.args[tuple_idx]] +# return src_var +# end diff --git a/src/utils.jl b/src/utils.jl index 0e47990..3edd001 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,5 +1,3 @@ -# isstruct(::Type{T}) where T = !isbitstype(T) && !(T <: AbstractArray) -# isstruct(obj) = !isbits(obj) && !isa(obj, AbstractArray) "Check if an object is of a struct type, i.e. not a number or array" isstruct(::Type{T}) where T = !isempty(fieldnames(T)) isstruct(obj) = isstruct(typeof(obj)) diff --git a/test/runtests.jl b/test/runtests.jl index 96aba04..c0665fe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,16 +2,17 @@ using Test using Yota using Yota: Tape, Input, Call, Constant, trace, play!, transform, binarize_ops using Yota: ∇mean, setfield_nested!, copy_with, simplegrad, remove_unused -using Yota: find_field_source_var, unwind_iterate, eliminate_common +using Yota: eliminate_common # unwind_iterate, find_field_source_var using Yota: unvcat, unhcat, uncat, ungetindex!, ungetindex +using Yota: gradcheck using CUDA +import ChainRulesCore: Composite, Zero CUDA.allowscalar(false) include("test_trace.jl") -include("gradcheck.jl") include("test_helpers.jl") include("test_simple.jl") include("test_grad.jl") diff --git a/test/test_grad.jl b/test/test_grad.jl index fcc0a84..bcbf4bd 100644 --- a/test/test_grad.jl +++ b/test/test_grad.jl @@ -29,6 +29,28 @@ end end +@testset "grad: iterate" begin + # iterate over tuple, e.g. for x in (1.0, 2.0, 3.0) + x = (1.0, 2.0, 3.0) + CT = Composite{typeof(x)} + @test grad(x -> iterate(x)[1], x)[2][1] == CT(1.0, Zero(), Zero()) + @test grad(x -> iterate(x, 2)[1], x)[2][1] == CT(Zero(), 1.0, Zero()) + @test grad(x -> iterate(x, 3)[1], x)[2][1] == CT(Zero(), Zero(), 1.0) + + # iterate over array, e.g. for x in [1.0, 2.0, 3.0] + x = [1.0, 2.0, 3.0] + # TODO (uncomment when scatter_add is fixed) + # @test grad(x -> iterate(x)[1], x)[2][1] == [1.0, 0, 0] + # @test grad(x -> iterate(x, 2)[1], x)[2][1] == [0, 1, 0] + # @test grad(x -> iterate(x, 3)[1], x)[2][1] == [0, 0, 1.0] + + x = (1:3) + @test !isdefined(grad(x -> iterate(x)[1], x)[2].gvars, 1) + @test !isdefined(grad(x -> iterate(x, 1)[1], x)[2].gvars, 1) + +end + + sum_bcast(x, y) = sum(x .+ y) @testset "special bcast" begin @@ -122,16 +144,11 @@ function add_points(x) return 2*l.p1.x + 5*l.p2.y end +# TODO: make a better test for constructors, not related to find_field_source_var @testset "grad: structs/new" begin # find_field_source_var _, tape = trace(add_points, rand()) - src_op = find_field_source_var(tape, tape[15]) - @test src_op.id == 1 - @test src_op.val isa Real - src_op = find_field_source_var(tape, tape[13]) - @test src_op.id == 3 - @test src_op.val isa Point @test grad(add_points, rand())[2][1] == 7 end diff --git a/test/test_update.jl b/test/test_update.jl index 7c08a86..a9dde06 100644 --- a/test/test_update.jl +++ b/test/test_update.jl @@ -1,5 +1,5 @@ import ChainRulesCore.Composite - + mutable struct A t::Array{Float64, 1} @@ -51,7 +51,7 @@ end @test x == -xo # with actual grad update - b = B(A(ones(4), 1.0), 1.0) + b = B(A(ones(4), 1.0), 1.0) _, g = grad(b -> sum(b.a.t) + b.a.s + b.s, b) update!(b, g[1], (x, gx) -> x .- 0.5gx) @test b.a.t == [0.5, 0.5, 0.5, 0.5]